package org.zjvis.datascience.common.algopy;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.zjvis.datascience.common.constant.DataJsonConstant;
import org.zjvis.datascience.common.dto.TaskDTO;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.enums.AlgPyEnum;
import org.zjvis.datascience.common.util.task.TaskDTOUtil;
import org.zjvis.datascience.common.util.task.TaskInstanceDTOUtil;
import org.zjvis.datascience.common.util.ToolUtil;

import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import static org.zjvis.datascience.common.constant.DataJsonConstant.INSTANCE_SET_PARAM_HEADER;
import static org.zjvis.datascience.common.constant.DataJsonConstant.TASK_ALG_TYPE;

/**
 * @description Helper utilities for feature-engineering operators: keeps the column-selection
 * parameters of task instances in sync with the columns actually available from their input.
 * @date 2021-12-24
 */
public class ImputationUtil {

    private static final String FORMITEM_IDX = "formItem";

    private static final String FORMDATA_IDX = "formData";

    private static final String PROPS_IDX = "props";

    private static final String OPTIONS_IDX = "options";

    private static final String LABEL_IDX = "label";

    private static final String VALUE_IDX = "value";

    private static final String COLS_IDX = "cols";

    private static final String SINGLE_COL_IDX = "col";

    private static final String NUMBER_STR = "number";

    /** Placeholder stored in {@code formData.col} meaning "no selectable feature column". */
    private static final String NO_FEATURE_COL_PLACEHOLDER = "<无可选特征列>";

    /** Hint stored in {@code formData.col}: "make sure the parent node has numeric feature columns". */
    private static final String NO_NUMERIC_COL_HINT = "<请确保父节点有数值型特征列>";

    /** Hint stored in {@code formData.col}: "please reselect the feature column". */
    private static final String RESELECT_COL_HINT = "<请重新选择特征列>";

    /** Key path of the multi-column selection inside a task's data JSON. */
    private static final String TASK_SELECTED_COLS_PATH = "setParams[0].formData.cols";

    /**
     * Same as {@link #initParams(TaskInstanceDTO, TaskDTO)} without a task, i.e. the
     * multi-column selection is only restricted to the available numeric columns.
     *
     * @param instance the task instance whose data JSON is updated in place
     */
    public static void initParams(TaskInstanceDTO instance) {
        initParams(instance, null);
    }

    /**
     * Returns the position, inside the first set-params entry's {@code formItem} array, of the
     * form item whose {@code props.options} hold the selectable feature columns. The layout
     * depends on the operator type stored under {@code TASK_ALG_TYPE}.
     *
     * @param jsonObject the instance data JSON
     * @return 1 for {@code FEATURE_SCALING}, {@code ANOMALY_STAT} and {@code FEATURE_SMOOTHING};
     *         0 for every other operator type (e.g. {@code TIMESERIES_DECOMPOSE},
     *         {@code IMPUTATION_STAT}, {@code IMPUTATION_MULTI}, {@code ANOMALY_KNN}) and when
     *         no operator type is recorded
     */
    private static int getColsSavedIndex(JSONObject jsonObject) {
        Integer algType = jsonObject.getInteger(TASK_ALG_TYPE);
        if (algType == null) {
            // No operator type recorded: use the default layout instead of failing with an
            // unboxing NullPointerException.
            return 0;
        }
        // Compare as a primitive so the checks stay value comparisons even if getVal() is boxed.
        int type = algType;
        if (type == AlgPyEnum.FEATURE_SCALING.getVal()
                || type == AlgPyEnum.ANOMALY_STAT.getVal()
                || type == AlgPyEnum.FEATURE_SMOOTHING.getVal()) {
            return 1;
        }
        return 0;
    }

    /**
     * Refreshes the column-related parameters of an instance from its first input.
     *
     * <p>The input columns that {@code ToolUtil.filterTypeAndCol} keeps for type
     * {@code "number"} become the options of the column form item (see
     * {@link #getColsSavedIndex}). Then {@code formData.col} and {@code formData.cols} are
     * reconciled against those columns. When {@code taskDTO} is given, {@code formData.cols}
     * is further restricted to the task's own selection ({@code setParams[0].formData.cols}),
     * if the task has one.
     *
     * <p>Any exception raised while updating the parameters marks the instance with the
     * initialization error flag instead of propagating. Failures while reading the input
     * columns (before that point) still propagate to the caller.
     *
     * @param instance the task instance whose data JSON is updated in place
     * @param taskDTO  optional task whose saved column selection limits {@code formData.cols};
     *                 may be {@code null}
     */
    public static void initParams(TaskInstanceDTO instance, TaskDTO taskDTO) {
        JSONObject instanceDataJson = TaskInstanceDTOUtil.getInstanceDataJson(instance);
        JSONObject input = TaskInstanceDTOUtil.getJsonArrayInJson(instanceDataJson, DataJsonConstant.INPUT_HEADER).getJSONObject(0);
        List<String> tableCols = input.getJSONArray(DataJsonConstant.TABLE_COLS_IDS).toJavaList(String.class);
        List<String> colTypes = input.getJSONArray(DataJsonConstant.COL_TYPES_IDS).toJavaList(String.class);

        try {
            List<String> columnsFiltered = ToolUtil.filterTypeAndCol(tableCols, colTypes, NUMBER_STR).getKey();

            JSONObject setParams = instanceDataJson.getJSONArray(INSTANCE_SET_PARAM_HEADER).getJSONObject(0);
            setParams.getJSONArray(FORMITEM_IDX).getJSONObject(getColsSavedIndex(instanceDataJson))
                    .getJSONObject(PROPS_IDX).put(OPTIONS_IDX, buildOptions(columnsFiltered));

            JSONObject formData = setParams.getJSONObject(FORMDATA_IDX);
            reconcileSingleCol(formData, columnsFiltered);
            reconcileMultiCols(formData, columnsFiltered, taskDTO);

            // Store the (mutated) first set-params entry back as a one-element array.
            JSONArray newSetParams = new JSONArray();
            newSetParams.add(setParams);
            instanceDataJson.put(INSTANCE_SET_PARAM_HEADER, newSetParams);

            TaskInstanceDTOUtil.updateInstanceDataJson(instance, instanceDataJson);
        } catch (Exception e) {
            // Deliberate best-effort: flag the instance as failed to initialize rather than
            // propagating. NOTE(review): the exception itself is dropped; consider logging it.
            TaskInstanceDTOUtil.setErrorFlagInJson(instance);
        }
    }

    /** Builds {@code {label, value}} option objects, one per column name. */
    private static JSONArray buildOptions(List<String> columns) {
        JSONArray options = new JSONArray();
        for (String col : columns) {
            JSONObject option = new JSONObject();
            option.put(LABEL_IDX, col);
            option.put(VALUE_IDX, col);
            options.add(option);
        }
        return options;
    }

    /**
     * Reconciles the single-column selection {@code formData.col} (if present) with the
     * available columns. The "no selectable column" placeholder is replaced by the first
     * available column, or by a hint when there is none. Any other value that is no longer
     * available is replaced by a "please reselect" hint.
     */
    private static void reconcileSingleCol(JSONObject formData, List<String> columnsFiltered) {
        String selected = formData.getString(SINGLE_COL_IDX);
        if (selected == null) {
            return;
        }
        if (NO_FEATURE_COL_PLACEHOLDER.equals(selected)) {
            formData.put(SINGLE_COL_IDX, columnsFiltered.isEmpty() ? NO_NUMERIC_COL_HINT : columnsFiltered.get(0));
        } else if (!columnsFiltered.contains(selected)) {
            formData.put(SINGLE_COL_IDX, RESELECT_COL_HINT);
        }
    }

    /**
     * Restricts the multi-column selection {@code formData.cols} (if present) to the available
     * columns and, when a task is given and has its own saved selection, to that selection too.
     * The filtered list is always written back (previously it was discarded when no task was
     * given).
     */
    private static void reconcileMultiCols(JSONObject formData, List<String> columnsFiltered, TaskDTO taskDTO) {
        JSONArray savedCols = formData.getJSONArray(COLS_IDX);
        if (savedCols == null) {
            return;
        }
        // toJavaList returns a detached copy, so the result must be put back explicitly.
        List<String> cols = savedCols.toJavaList(String.class);
        cols.retainAll(columnsFiltered);

        if (taskDTO != null) {
            Object taskCols = TaskDTOUtil.getValueByKey(TASK_SELECTED_COLS_PATH, taskDTO);
            // A task without a saved selection imposes no extra restriction.
            if (taskCols instanceof JSONArray) {
                cols.retainAll(((JSONArray) taskCols).toJavaList(String.class));
            }
        }

        JSONArray filtered = new JSONArray();
        filtered.addAll(cols);
        formData.put(COLS_IDX, filtered);
    }

    /**
     * Prunes column references in the instance's parameters that are no longer among the
     * table columns of its own first input.
     *
     * <p>It removes two kinds of entries from the first set-params entry:
     * <ul>
     *   <li>options of the column form item (same index as {@link #initParams}) whose label is
     *       not an input column;</li>
     *   <li>entries of {@code formData.cols} that are not input columns.</li>
     * </ul>
     * A missing options array or {@code formData.cols} array is skipped.
     *
     * <p>NOTE: this only removes stale entries. It does not add newly available input columns
     * as options.
     *
     * @param instance the task instance whose data JSON is updated in place
     */
    public static void updateParams(TaskInstanceDTO instance) {
        JSONObject instanceDataJson = TaskInstanceDTOUtil.getInstanceDataJson(instance);
        JSONObject input = TaskInstanceDTOUtil.getJsonArrayInJson(instanceDataJson, DataJsonConstant.INPUT_HEADER).getJSONObject(0);
        Set<String> tableCols = new HashSet<>(input.getJSONArray(DataJsonConstant.TABLE_COLS_IDS).toJavaList(String.class));
        JSONObject paramInfo = TaskInstanceDTOUtil.getJsonArrayInJson(instanceDataJson, INSTANCE_SET_PARAM_HEADER).getJSONObject(0);

        // Use the same form-item position initParams writes the column options to.
        JSONArray options = paramInfo.getJSONArray(FORMITEM_IDX).getJSONObject(getColsSavedIndex(instanceDataJson))
                .getJSONObject(PROPS_IDX).getJSONArray(OPTIONS_IDX);
        if (options != null) {
            options.removeIf(option -> !tableCols.contains(((JSONObject) option).getString(LABEL_IDX)));
        }

        JSONArray cols = paramInfo.getJSONObject(FORMDATA_IDX).getJSONArray(COLS_IDX);
        if (cols != null) {
            cols.removeIf(col -> !tableCols.contains(col));
        }

        TaskInstanceDTOUtil.updateInstanceDataJson(instance, instanceDataJson);
    }

    /**
     * Synchronizes a task's data JSON with an instance's parameter settings.
     *
     * <p>Background: ALGOPY previously updated only the task's result, not the instance. This
     * copies the instance's first set-params entry into the task (as a one-element array under
     * {@code INSTANCE_SET_PARAM_HEADER}). It then copies the task's output
     * ({@code DataJsonConstant.OUTPUT_HEADER}) back into the instance.
     *
     * @param instance source of the parameter settings; receives the task's output
     * @param taskDTO  receives the instance's parameter settings; source of the output
     */
    public static void syncParams(TaskInstanceDTO instance, TaskDTO taskDTO) {
        JSONObject paramInfo = TaskInstanceDTOUtil.getJsonArrayInJson(instance, INSTANCE_SET_PARAM_HEADER).getJSONObject(0);
        JSONArray setParams = new JSONArray();
        setParams.add(paramInfo);
        // Overwrite the task's setParams with the instance's.
        TaskDTOUtil.updateJsonWithSpecificKey(taskDTO, INSTANCE_SET_PARAM_HEADER, setParams);
        // Overwrite the instance's output with the task's result.
        Object taskOutput = TaskDTOUtil.getValueByKey(DataJsonConstant.OUTPUT_HEADER, taskDTO);
        TaskInstanceDTOUtil.updateJsonWithSpecificKey(instance, DataJsonConstant.OUTPUT_HEADER, taskOutput);
    }


}
